from functools import partial
from typing import Any, Callable, Sequence, Optional
#from flax import linen as nn
from jax import jit
from jax import numpy as jnp
from jax import random
from jax.example_libraries import optimizers
import neural_tangents as nt
from jax import grad, value_and_grad
import matplotlib.pyplot as plt
#import keras
from neural_tangents import stax
import wandb
import numpy as np
import argparse
import os



# Command-line interface. Flags are registered from the table below (in table
# order) and parsed once into the module-level `args` namespace.
parser = argparse.ArgumentParser(
    description='\n    tests of multiple-layer-per-block models\n    ',
    formatter_class=argparse.RawTextHelpFormatter,
)

# (flag, keyword arguments for parser.add_argument)
_CLI_FLAGS = [
    ('--lr', {'default': 0.1, 'type': float, 'help': 'learning rate'}),
    ('--mom', {'default': 0.0, 'type': float, 'help': 'momentum'}),
    ('--wd', {'default': 0.0, 'type': float, 'help': 'weight decay'}),
    ('--batch_size', {'type': int, 'default': 64}),
    ('--arch', {'type': str, 'default': 'NT_CNN'}),
    ('--dataset', {'type': str, 'default': 'cifar5m'}),
    ('--optimizer', {'default': 'sgd'}),
    # NOTE: the int default (32) is kept as-is; argparse does not apply
    # `type` to non-string defaults, so it shows up as "32" in run names.
    ('--width_mult', {'type': float, 'default': 32}),
    ('--depth_mult', {'type': int, 'default': 1}),
    # (A '--beta' residual-branch scaling flag was disabled here.)
    ('--gamma_zero', {'type': float, 'default': 1.0,
                      'help': 'controls the amount of feature learning.'}),
    # options: linearize, MF
    ('--parametr', {'type': str, 'default': 'linearize'}),
    ('--epochs', {'type': int, 'default': 1}),
    ('--save_model', {'action': 'store_true'}),
    ('--num_workers', {'type': int, 'default': 0}),
    ('--test_num_workers', {'type': int, 'default': 0}),
    ('--num_ens', {'type': int, 'default': 4}),
    ('--data_mult', {'default': 1.0, 'type': float}),
    # reference width for the MF gamma scaling
    ('--base_width', {'default': 128, 'type': float}),
]

for _flag, _kwargs in _CLI_FLAGS:
    parser.add_argument(_flag, **_kwargs)

args = parser.parse_args()

# Number of independently initialised networks trained below.
num_ens = args.num_ens

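# NOTE: load_CIFAR5M reads f"{data_dir}/cifar5m_part{i}.npz"; pass the CIFAR-5M directory as data_dir (the default '' resolves to the filesystem root, e.g. "/cifar5m_part0.npz").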

def load_CIFAR5M(num_files = 4, data_mult = 1.0, num_val = 2048, data_dir = ''):
    """Load a binary (+1 / -1) classification task from CIFAR-5M ``.npz`` parts.

    Training examples come from ``cifar5m_part0.npz`` .. ``cifar5m_part{num_files-1}.npz``.
    The first ``int(total * data_mult)`` examples, in file order, are kept.
    Test examples are the first ``num_val`` entries of ``cifar5m_part5.npz``,
    so the splits are disjoint as long as ``num_files <= 5``.

    Each archive is expected to hold exactly two arrays: images first, then
    integer labels. NOTE(review): this is inferred from the unpacking order;
    confirm against the data files.

    Args:
        num_files: number of training parts to read.
        data_mult: fraction of the available training examples to keep, in [0, 1].
        num_val: number of test examples taken from part 5.
        data_dir: directory containing the ``.npz`` files. The default ``''``
            produces absolute paths such as ``/cifar5m_part0.npz``.

    Returns:
        ``(x_train, y_train, x_test, y_test)``, where:

        - ``x_train`` is raw ``uint8`` of shape ``(train_size, 32, 32, 3)``;
          the caller must rescale it.
        - ``x_test`` is already scaled to ``[-1, 1]``.
        - Both label arrays are float columns of shape ``(N, 1)`` with values
          in {-1, +1}.

    Raises:
        ValueError: if ``data_mult`` is outside ``[0, 1]``, or if an archive
            does not contain exactly two arrays.
    """
    if not 0.0 <= data_mult <= 1.0:
        # Values > 1 would size the buffer beyond the available data and leave
        # zero-filled rows behind.
        raise ValueError(f"data_mult must be in [0, 1], got {data_mult}")

    print(f"reading cifar 5m data from {data_dir}")
    train_files = [f"{data_dir}/cifar5m_part{i}.npz" for i in range(num_files)]

    # First pass: only the label arrays are needed to size the buffers, so the
    # much larger image arrays are not loaded here.
    total_size = 0
    for file_name in train_files:
        with np.load(file_name) as curr_data:
            _, y_key = curr_data.files
            total_size += curr_data[y_key].shape[0]

    train_size = int(total_size*data_mult)
    print(f"train size = {train_size}")

    # Images stay uint8 to keep this (potentially multi-GB) buffer compact.
    data_x = np.zeros((train_size, 32, 32, 3), dtype=np.uint8)
    data_y = np.zeros((train_size,), dtype=int)

    # Second pass: fill the buffers in file order. The file that crosses the
    # train_size boundary is copied partially, so no zero rows are left over.
    idx = 0
    for i, file_name in enumerate(train_files):
        if idx >= train_size:
            break
        print("reading file %d" % i)
        X, y = _read_npz_pair(file_name)
        n_take = min(y.shape[0], train_size - idx)
        data_x[idx:idx + n_take] = X[:n_take]
        data_y[idx:idx + n_take] = y[:n_take]
        idx += n_take

    # Test split: first num_val examples of part 5, scaled to [-1, 1].
    X_val, y_val = _read_npz_pair(f"{data_dir}/cifar5m_part5.npz")
    x_test = (X_val[:num_val]/255.0- 0.5)/0.5
    y_test = _binary_labels(y_val[:num_val])

    return data_x, _binary_labels(data_y), x_test, y_test


def _read_npz_pair(file_name):
    """Return the two arrays stored in the ``.npz`` archive, in stored order.

    The archive is closed before returning. Unpacking raises ValueError when
    the archive does not hold exactly two arrays.
    """
    with np.load(file_name) as archive:
        first_key, second_key = archive.files
        return archive[first_key], archive[second_key]


def _binary_labels(labels):
    """Map integer class ids to a float ``(N, 1)`` column of +1 / -1.

    Ids 2..7 map to +1; ids <= 1 or >= 8 map to -1. Presumably this is the
    CIFAR-10 class order, i.e. animals (+1) vs. vehicles (-1); confirm.
    """
    pm1 = 1.0 * (labels >=2)*(labels <= 7) + -1.0 * (labels<=1) - 1.0*(labels>=8)
    return pm1.reshape((pm1.shape[0], 1))
    
    
# Load the data. --data_mult is forwarded so the training-set size matches the
# data_mult recorded in the run name; data_dir keeps load_CIFAR5M's default.
x_train, y_train, x_test, y_test = load_CIFAR5M(num_files = 5, data_mult = args.data_mult)

# Type alias; not referenced elsewhere in this file.
_ModuleDef = Any


def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  """Residual block: (ReLU -> 3x3 conv) twice, summed with a shortcut branch.

  Args:
    channels: output channels of every convolution in the block.
    strides: stride of the first main-branch conv (and of the conv shortcut,
      when one is used).
    channel_mismatch: if True the shortcut is a strided 3x3 conv; otherwise it
      is the identity.

  Returns:
    The ``(init_fn, apply_fn, kernel_fn)`` triple built by ``stax.serial``.
  """
  main_branch = stax.serial(
      stax.Relu(),
      stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(),
      stax.Conv(channels, (3, 3), padding='SAME'),
  )
  if channel_mismatch:
    shortcut = stax.Conv(channels, (3, 3), strides, padding='SAME')
  else:
    shortcut = stax.Identity()
  # Duplicate the input, run both branches, then add their outputs.
  return stax.serial(
      stax.FanOut(2),
      stax.parallel(main_branch, shortcut),
      stax.FanInSum(),
  )

def WideResnetGroup(n, channels, strides=(1, 1), mismatch=True):
  """Stack ``n`` residual blocks (at least one) of width ``channels``.

  Only the first block uses ``strides`` and, when ``mismatch`` is True, a conv
  shortcut; the remaining ``n - 1`` blocks are stride-1 with identity shortcuts.
  """
  leading = WideResnetBlock(channels, strides, channel_mismatch=mismatch)
  trailing = [WideResnetBlock(channels, (1, 1)) for _ in range(n - 1)]
  return stax.serial(leading, *trailing)

def WideResnet(block_size, k, num_classes):
  """Shallow wide ResNet: stem conv, one residual group, global pooling, readout.

  Args:
    block_size: number of residual blocks in the single group.
    k: channel width of every convolution.
    num_classes: output dimension of the final dense layer.
  """
  layers = [
      stax.Conv(k, (3, 3), padding='SAME'),
      # Single stride-1 group with identity shortcuts. Extra groups (including
      # stride-2 ones) were previously disabled here; see WideResnet_Deep.
      WideResnetGroup(block_size, k, mismatch=False),
      # The 32x32 window matches the 32x32 images produced by load_CIFAR5M.
      stax.AvgPool((32, 32)),
      stax.Flatten(),
      # Readout with weight std 1.0 and bias std 0.0.
      stax.Dense(num_classes, 1., 0.),
  ]
  return stax.serial(*layers)


def WideResnet_Deep(block_size, k, num_classes):
  """Deeper wide ResNet: stem conv, two stride-1 groups, one stride-2 group.

  Args:
    block_size: number of residual blocks per group.
    k: channel width of every convolution.
    num_classes: output dimension of the final dense layer.
  """
  layers = [stax.Conv(k, (3, 3), padding='SAME')]
  # Two stride-1 groups with identity shortcuts.
  layers.extend(WideResnetGroup(block_size, k, mismatch=False) for _ in range(2))
  # Stride-2 group with a conv shortcut (mismatch defaults to True); this halves
  # the spatial size, hence the 16x16 pooling window below.
  layers.append(WideResnetGroup(block_size, k, (2, 2)))
  layers.extend([
      stax.AvgPool((16, 16)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.),
  ])
  return stax.serial(*layers)




def get_run_name_NT(args):
    """Build a '/'-separated run identifier from the experiment hyperparameters.

    ``lr`` and ``wd`` are printed with 4 decimals and ``mom`` with 2; every
    other field uses its plain ``str`` form.
    """
    segments = (
        f"ens_model_{args.arch}",
        f"dataset_{args.dataset}",
        f"data_mult_{args.data_mult}",
        f"lr_{args.lr:.4f}",
        f"mom_{args.mom:.2f}",
        f"wd_{args.wd:.4f}",
        f"batch_size_{args.batch_size}",
        f"epochs_{args.epochs}",
        f"width_mult_{args.width_mult}",
        f"depth_mult_{args.depth_mult}",
        f"parametr_{args.parametr}",
        f"gamma_zero_{args.gamma_zero}",
    )
    return "/".join(segments)

import warnings

# Directory prefix for the saved per-member test-loss arrays
# ('' writes them relative to the current working directory).
save_dir = ''

run_name = get_run_name_NT(args)

# Build the model before wandb.init, so that an unknown --arch aborts without
# creating an empty tracking run. Previously any unrecognised value silently
# fell back to the shallow network while the run name recorded the bogus arch.
_ARCHITECTURES = {
    'NT_CNN': WideResnet,
    'NT_CNN_Deep': WideResnet_Deep,
}
if args.arch not in _ARCHITECTURES:
    raise ValueError(
        f"unknown --arch {args.arch!r}; expected one of {sorted(_ARCHITECTURES)}")
init_fn, apply_fn, kernel_fn = _ARCHITECTURES[args.arch](
    block_size=args.depth_mult, k=int(args.width_mult), num_classes=1)

# Step size: under the mean-field ('MF') parametrization the base learning
# rate is multiplied by gamma_zero^2 * width_mult / base_width.
if args.parametr == 'MF':
    step_size = args.gamma_zero**2 * args.width_mult / args.base_width * args.lr
else:
    step_size = args.lr

if args.mom:
    # Honour --mom (it is logged in the run name) with SGD + momentum.
    opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=args.mom)
else:
    opt_init, opt_update, get_params = optimizers.sgd(step_size)

if args.wd or args.optimizer != 'sgd':
    # Surface flags that are recorded in the run name but have no effect.
    warnings.warn(
        f"--wd={args.wd} and --optimizer={args.optimizer!r} are not implemented; "
        "training uses SGD (with --mom momentum) and no weight decay.")

# Weights & Biases tracking; all CLI hyperparameters go into the run config.
wandb.init(project="scaling_NL", config=args.__dict__)
wandb.run.name = run_name

# Number of SGD steps so that `epochs` passes are made over the training set.
T = int(args.epochs * x_train.shape[0] / args.batch_size)
batchsize = args.batch_size

save_path = os.path.join(save_dir, run_name.replace("/", "-"))

# Evaluate/log every `log_every` steps; checkpoint the loss curve every `save_every`.
log_every = 100
save_every = 1000

# Train `num_ens` independently initialised networks (seed = member index).
# Test losses are saved per member to f"{save_path}_test_losses_e_{e}.npy":
# one entry before the update at every multiple of `log_every`, plus a final
# entry for the fully trained member.
for e in range(num_ens):
    losses_e = []

    _, params = init_fn(random.PRNGKey(e), input_shape=x_train[0:3].shape)
    opt = opt_init(params)

    # Only the change from initialisation is fitted: the output at the initial
    # `params` is subtracted, so every member starts from a zero predictor.
    # The lambdas read `params` / `apply_lin` / `apply_shift` from module
    # scope; they are traced and used only within this iteration.
    if args.parametr == 'MF':
        # Mean-field output, rescaled by sqrt(gamma_zero^2 * width_mult / base_width).
        apply_shift = jit(lambda p, X: (apply_fn(p, X) - apply_fn(params, X)) / jnp.sqrt(args.gamma_zero**2 * args.width_mult / args.base_width))
    else:
        # First-order (linearised) model around the initial parameters.
        apply_lin = nt.linearize(apply_fn, params)
        apply_shift = jit(lambda p, X: apply_lin(p, X) - apply_lin(params, X))

    loss = jit(lambda p, X, y: jnp.mean((y - apply_shift(p, X))**2))
    val_grad_fn = jit(value_and_grad(loss))

    # Running training-loss sum and step count since the last wandb.log call.
    train_sum = 0.0
    steps_since_log = 0

    for t in range(T):
        # Sequential mini-batches that wrap around the training set.
        ind = (batchsize * t) % x_train.shape[0]
        Xt = jnp.array(x_train[ind:ind + batchsize])
        Xt = (Xt / 255.0 - 0.5) / 0.5  # uint8 -> [-1, 1], same scaling as x_test
        yt = jnp.array(y_train[ind:ind + batchsize])

        if t % log_every == 0:
            loss_t = loss(get_params(opt), x_test, y_test)
            losses_e.append(loss_t)
            if steps_since_log:
                wandb.log({'train_loss': train_sum / steps_since_log, 'test_loss': loss_t})
            train_sum = 0.0
            steps_since_log = 0

        if t % save_every == 0:
            np.save(save_path + f'_test_losses_e_{e}.npy', np.array(losses_e))

        loss_tr, grad_t = val_grad_fn(get_params(opt), Xt, yt)
        train_sum += loss_tr
        steps_since_log += 1
        opt = opt_update(t, grad_t, opt)

    # Every in-loop evaluation happens *before* an update. Without this step,
    # the trained member's test loss and the last training window (up to
    # log_every steps) would never be recorded.
    loss_final = loss(get_params(opt), x_test, y_test)
    losses_e.append(loss_final)
    if steps_since_log:
        wandb.log({'train_loss': train_sum / steps_since_log, 'test_loss': loss_final})

    np.save(save_path + f'_test_losses_e_{e}.npy', np.array(losses_e))

        
